{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4f32eb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Execute this cell to install dependencies\n",
    "%pip install sf-hamilton[visualization]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c1fa180",
   "metadata": {},
   "source": [
    "# Hamilton for ML dataflows [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dagworks-inc/hamilton/blob/main/examples/model_examples/scikit-learn/Hamilton_for_ML_dataflows.ipynb) [![GitHub badge](https://img.shields.io/badge/github-view_source-2b3137?logo=github)](https://github.com/apache/hamilton/blob/main/examples/model_examples/scikit-learn/Hamilton_for_ML_dataflows.ipynb)\n",
    "\n",
    "\n",
    "#### Requirements:\n",
    "\n",
    "- Install dependencies (listed in `requirements.txt`)\n",
    "\n",
    "More details [here](https://github.com/apache/hamilton/blob/main/examples/model_examples/scikit-learn/README.md#using-hamilton-for-ml-dataflows).\n",
    "\n",
    "***\n",
    "\n",
    "Uncomment and run the cell below if you are in a Google Colab environment. It will:\n",
    "1. Mount google drive. You will be asked to authenticate and give permissions.\n",
    "2. Change directory to google drive.\n",
    "3. Make a directory \"hamilton-tutorials\"\n",
    "4. Change directory to it.\n",
    "5. Clone this repository to your google drive\n",
    "6. Move your current directory to the hello_world example\n",
    "7. Install requirements.\n",
    "\n",
    "This means that any modifications will be saved, and you won't lose them if you close your browser."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c498283",
   "metadata": {},
   "outputs": [],
   "source": [
    "## 1. Mount google drive\n",
    "# from google.colab import drive\n",
    "# drive.mount('/content/drive')\n",
    "## 2. Change directory to google drive.\n",
    "# %cd /content/drive/MyDrive\n",
    "## 3. Make a directory \"hamilton-tutorials\"\n",
    "# !mkdir hamilton-tutorials\n",
    "## 4. Change directory to it.\n",
    "# %cd hamilton-tutorials\n",
    "## 5. Clone this repository to your google drive\n",
    "# !git clone https://github.com/apache/hamilton/\n",
    "## 6. Move your current directory to the hello_world example\n",
    "# %cd hamilton/examples/hello_world\n",
    "## 7. Install requirements.\n",
    "# %pip install -r requirements.txt\n",
    "# clear_output()  # optionally clear outputs\n",
    "# To check your current working directory you can type `!pwd` in a cell and run it."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b44732ae",
   "metadata": {},
   "source": [
    "***\n",
    "Here we have a simple example showing how you can write a ML training and evaluation workflow with Hamilton. \n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "817c2996",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Example script showing how one might setup a generic model training pipeline that is quickly configurable.\n",
    "\"\"\"\n",
    "\n",
    "import digit_loader\n",
    "import iris_loader\n",
    "import my_train_evaluate_logic\n",
    "\n",
    "from hamilton import base, driver\n",
    "\n",
    "\n",
    "def get_data_loader(data_set: str):\n",
    "    \"\"\"Returns the module to load that will procur data -- the data loaders all have to define the same functions.\"\"\"\n",
    "    if data_set == \"iris\":\n",
    "        return iris_loader\n",
    "    elif data_set == \"digits\":\n",
    "        return digit_loader\n",
    "    else:\n",
    "        raise ValueError(f\"Unknown data_name {data_set}.\")\n",
    "\n",
    "\n",
    "def get_model_config(model_type: str) -> dict:\n",
    "    \"\"\"Returns model type specific configuration\"\"\"\n",
    "    if model_type == \"svm\":\n",
    "        return {\"clf\": \"svm\", \"gamma\": 0.001}\n",
    "    elif model_type == \"logistic\":\n",
    "        return {\"clf\": \"logistic\", \"penalty\": \"l2\"}\n",
    "    else:\n",
    "        raise ValueError(f\"Unsupported model {model_type}.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dec9f8b4",
   "metadata": {},
   "source": [
    "***\n",
    "For the purpose of this experiment, lets apply the following configuration:\n",
    "\n",
    "- `_data_set` = 'digits'\n",
    "- `_model_type` = 'logistic'\n",
    "\n",
    "More details [here](https://github.com/apache/hamilton/blob/main/examples/model_examples/scikit-learn/README.md).\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fbc5669e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Note: Hamilton collects completely anonymous data about usage. This will help us improve Hamilton over time. See https://github.com/apache/hamilton#usage-analytics--data-privacy for details.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "classification_report :\n",
      "               precision    recall  f1-score   support\n",
      "\n",
      "           0       1.00      0.99      0.99        91\n",
      "           1       0.92      0.95      0.94        84\n",
      "           2       0.98      1.00      0.99        83\n",
      "           3       0.99      0.98      0.98        81\n",
      "           4       0.95      0.99      0.97        95\n",
      "           5       0.98      0.94      0.96        97\n",
      "           6       0.97      0.98      0.97        85\n",
      "           7       0.98      0.98      0.98        96\n",
      "           8       0.91      0.90      0.91        96\n",
      "           9       0.96      0.93      0.94        91\n",
      "\n",
      "    accuracy                           0.96       899\n",
      "   macro avg       0.96      0.96      0.96       899\n",
      "weighted avg       0.96      0.96      0.96       899\n",
      "\n",
      "confusion_matrix :\n",
      " [[90  0  0  0  1  0  0  0  0  0]\n",
      " [ 0 80  0  0  1  0  1  0  2  0]\n",
      " [ 0  0 83  0  0  0  0  0  0  0]\n",
      " [ 0  0  0 79  0  0  0  1  0  1]\n",
      " [ 0  1  0  0 94  0  0  0  0  0]\n",
      " [ 0  1  0  1  1 91  0  1  0  2]\n",
      " [ 0  0  0  0  0  0 83  0  2  0]\n",
      " [ 0  0  0  0  1  0  0 94  0  1]\n",
      " [ 0  5  2  0  0  1  2  0 86  0]\n",
      " [ 0  0  0  0  1  1  0  0  4 85]]\n",
      "fit_clf :\n",
      " LogisticRegression()\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/flaviassantos/github/hamilton/venv/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n"
     ]
    }
   ],
   "source": [
    "_data_set = 'digits'  # the data set to load\n",
    "_model_type = 'logistic'  # the model type to fit and evaluate with\n",
    "\n",
    "dag_config = {\n",
    "    \"test_size_fraction\": 0.5,\n",
    "    \"shuffle_train_test_split\": True,\n",
    "}\n",
    "# augment config\n",
    "dag_config.update(get_model_config(_model_type))\n",
    "# get module with functions to load data\n",
    "data_module = get_data_loader(_data_set)\n",
    "# set the desired result container we want\n",
    "adapter = base.DefaultAdapter()\n",
    "\"\"\"\n",
    "What's cool about this, is that by simply changing the `dag_config` and the `data_module` we can\n",
    "reuse the logic in the `my_train_evaluate_logic` module very easily for different contexts and purposes if\n",
    "want to setup a generic model fitting and prediction dataflow!\n",
    "E.g. if we want to support a new data set, then we just need to add a new data loading module.\n",
    "E.g. if we want to support a new model type, then we just need to add a single conditional function\n",
    "     to my_train_evaluate_logic.\n",
    "\"\"\"\n",
    "dr = driver.Driver(dag_config, data_module, my_train_evaluate_logic, adapter=adapter)\n",
    "# ensure you have done \"pip install \"sf-hamilton[visualization]\"\" for the following to work:\n",
    "# dr.visualize_execution(['classification_report', 'confusion_matrix', 'fit_clf'],\n",
    "#                        f'./model_dag_{_data_set}_{_model_type}.dot', {\"format\": \"png\"})\n",
    "results = dr.execute([\"classification_report\", \"confusion_matrix\", \"fit_clf\"])\n",
    "for k, v in results.items():\n",
    "    print(k, \":\\n\", v)"
   ]
  },
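  {
   "cell_type": "markdown",
   "id": "9d41c7e2",
   "metadata": {},
   "source": [
    "***\n",
    "The claim that a new data set only needs a new data loading module can be sketched as follows. This is a minimal, hypothetical `wine_loader.py`, not part of the example's code. It assumes that loaders must expose the `feature_matrix`, `target`, and `target_names` nodes that appear in the execution graph in this notebook, and that these type annotations match what `my_train_evaluate_logic` expects (that module is not shown here).\n",
    "\n",
    "```python\n",
    "# wine_loader.py (hypothetical file name, for illustration only)\n",
    "import numpy as np\n",
    "from sklearn import datasets, utils\n",
    "\n",
    "\n",
    "def wine_data() -> utils.Bunch:\n",
    "    \"\"\"Loads the scikit-learn wine data set.\"\"\"\n",
    "    return datasets.load_wine()\n",
    "\n",
    "\n",
    "def target(wine_data: utils.Bunch) -> np.ndarray:\n",
    "    \"\"\"Class label for each sample.\"\"\"\n",
    "    return wine_data.target\n",
    "\n",
    "\n",
    "def target_names(wine_data: utils.Bunch) -> np.ndarray:\n",
    "    \"\"\"Human-readable name for each class.\"\"\"\n",
    "    return wine_data.target_names\n",
    "\n",
    "\n",
    "def feature_matrix(wine_data: utils.Bunch) -> np.ndarray:\n",
    "    \"\"\"Feature values, one row per sample.\"\"\"\n",
    "    return wine_data.data\n",
    "```\n",
    "\n",
    "To use such a module, `get_data_loader` would need one more branch that returns it, for example when `data_set == \"wine\"`.\n",
    "***"
   ]
  },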
  {
   "cell_type": "markdown",
   "id": "cc15cd9c",
   "metadata": {},
   "source": [
    "***\n",
    "Here is the graph of execution for the digits data set:\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "060abd35",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 9.0.0 (20230911.1827)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"729pt\" height=\"548pt\"\n",
       " viewBox=\"0.00 0.00 729.01 548.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 544)\">\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-544 725.01,-544 725.01,4 -4,4\"/>\n",
       "<!-- confusion_matrix -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>confusion_matrix</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"350.52,-36 238.52,-36 238.52,0 350.52,0 350.52,-36\"/>\n",
       "<text text-anchor=\"middle\" x=\"294.52\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">confusion_matrix</text>\n",
       "</g>\n",
       "<!-- classification_report -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>classification_report</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"200.65,-36 74.4,-36 74.4,0 200.65,0 200.65,-36\"/>\n",
       "<text text-anchor=\"middle\" x=\"137.52\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">classification_report</text>\n",
       "</g>\n",
       "<!-- y_train -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>y_train</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"414.52\" cy=\"-306\" rx=\"37.02\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"414.52\" y=\"-300.95\" font-family=\"Times,serif\" font-size=\"14.00\">y_train</text>\n",
       "</g>\n",
       "<!-- fit_clf -->\n",
       "<g id=\"node11\" class=\"node\">\n",
       "<title>fit_clf</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"403.52,-252 349.52,-252 349.52,-216 403.52,-216 403.52,-252\"/>\n",
       "<text text-anchor=\"middle\" x=\"376.52\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">fit_clf</text>\n",
       "</g>\n",
       "<!-- y_train&#45;&gt;fit_clf -->\n",
       "<g id=\"edge18\" class=\"edge\">\n",
       "<title>y_train&#45;&gt;fit_clf</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M405.52,-288.41C401.29,-280.62 396.14,-271.14 391.36,-262.33\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"394.56,-260.9 386.72,-253.78 388.41,-264.24 394.56,-260.9\"/>\n",
       "</g>\n",
       "<!-- y_test -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>y_test</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"154.52\" cy=\"-234\" rx=\"32.93\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"154.52\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">y_test</text>\n",
       "</g>\n",
       "<!-- y_test_with_labels -->\n",
       "<g id=\"node13\" class=\"node\">\n",
       "<title>y_test_with_labels</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"97.52\" cy=\"-90\" rx=\"80.01\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"97.52\" y=\"-84.95\" font-family=\"Times,serif\" font-size=\"14.00\">y_test_with_labels</text>\n",
       "</g>\n",
       "<!-- y_test&#45;&gt;y_test_with_labels -->\n",
       "<g id=\"edge20\" class=\"edge\">\n",
       "<title>y_test&#45;&gt;y_test_with_labels</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M136.39,-218.82C125.44,-209.15 112.34,-195.35 105.52,-180 97.09,-161.02 95.29,-137.54 95.49,-119.47\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"98.98,-119.84 95.83,-109.72 91.99,-119.6 98.98,-119.84\"/>\n",
       "</g>\n",
       "<!-- feature_matrix -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>feature_matrix</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"436.52\" cy=\"-450\" rx=\"65.68\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"436.52\" y=\"-444.95\" font-family=\"Times,serif\" font-size=\"14.00\">feature_matrix</text>\n",
       "</g>\n",
       "<!-- train_test_split_func -->\n",
       "<g id=\"node10\" class=\"node\">\n",
       "<title>train_test_split_func</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"320.52\" cy=\"-378\" rx=\"86.67\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"320.52\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">train_test_split_func</text>\n",
       "</g>\n",
       "<!-- feature_matrix&#45;&gt;train_test_split_func -->\n",
       "<g id=\"edge12\" class=\"edge\">\n",
       "<title>feature_matrix&#45;&gt;train_test_split_func</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M410.2,-433.12C394.67,-423.74 374.76,-411.73 357.65,-401.4\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"359.62,-398.5 349.25,-396.33 356,-404.5 359.62,-398.5\"/>\n",
       "</g>\n",
       "<!-- target -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>target</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"320.52\" cy=\"-450\" rx=\"31.9\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"320.52\" y=\"-444.95\" font-family=\"Times,serif\" font-size=\"14.00\">target</text>\n",
       "</g>\n",
       "<!-- target&#45;&gt;train_test_split_func -->\n",
       "<g id=\"edge13\" class=\"edge\">\n",
       "<title>target&#45;&gt;train_test_split_func</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M320.52,-431.7C320.52,-424.41 320.52,-415.73 320.52,-407.54\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"324.02,-407.62 320.52,-397.62 317.02,-407.62 324.02,-407.62\"/>\n",
       "</g>\n",
       "<!-- test_size_fraction -->\n",
       "<g id=\"node7\" class=\"node\">\n",
       "<title>test_size_fraction</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"620.52\" cy=\"-450\" rx=\"100.48\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"620.52\" y=\"-444.95\" font-family=\"Times,serif\" font-size=\"14.00\">Input: test_size_fraction</text>\n",
       "</g>\n",
       "<!-- test_size_fraction&#45;&gt;train_test_split_func -->\n",
       "<g id=\"edge14\" class=\"edge\">\n",
       "<title>test_size_fraction&#45;&gt;train_test_split_func</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M560.89,-435.09C510.99,-423.44 439.8,-406.83 387.85,-394.71\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"388.75,-391.33 378.22,-392.46 387.16,-398.14 388.75,-391.33\"/>\n",
       "</g>\n",
       "<!-- X_test -->\n",
       "<g id=\"node8\" class=\"node\">\n",
       "<title>X_test</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"277.52\" cy=\"-234\" rx=\"34.97\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"277.52\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">X_test</text>\n",
       "</g>\n",
       "<!-- predicted_output -->\n",
       "<g id=\"node19\" class=\"node\">\n",
       "<title>predicted_output</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"327.52\" cy=\"-162\" rx=\"73.36\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"327.52\" y=\"-156.95\" font-family=\"Times,serif\" font-size=\"14.00\">predicted_output</text>\n",
       "</g>\n",
       "<!-- X_test&#45;&gt;predicted_output -->\n",
       "<g id=\"edge25\" class=\"edge\">\n",
       "<title>X_test&#45;&gt;predicted_output</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M289.12,-216.76C294.94,-208.61 302.15,-198.53 308.73,-189.31\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"311.45,-191.53 314.41,-181.36 305.75,-187.46 311.45,-191.53\"/>\n",
       "</g>\n",
       "<!-- predicted_output_with_labels -->\n",
       "<g id=\"node9\" class=\"node\">\n",
       "<title>predicted_output_with_labels</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"315.52\" cy=\"-90\" rx=\"120.45\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"315.52\" y=\"-84.95\" font-family=\"Times,serif\" font-size=\"14.00\">predicted_output_with_labels</text>\n",
       "</g>\n",
       "<!-- predicted_output_with_labels&#45;&gt;confusion_matrix -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>predicted_output_with_labels&#45;&gt;confusion_matrix</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M310.33,-71.7C308.09,-64.24 305.42,-55.32 302.91,-46.97\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"306.32,-46.13 300.09,-37.55 299.61,-48.14 306.32,-46.13\"/>\n",
       "</g>\n",
       "<!-- predicted_output_with_labels&#45;&gt;classification_report -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>predicted_output_with_labels&#45;&gt;classification_report</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M274.24,-72.76C249.96,-63.22 218.95,-51.02 192.6,-40.66\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"194.01,-37.45 183.42,-37.05 191.44,-43.97 194.01,-37.45\"/>\n",
       "</g>\n",
       "<!-- train_test_split_func&#45;&gt;y_train -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>train_test_split_func&#45;&gt;y_train</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M342.8,-360.41C355.68,-350.82 372.02,-338.65 385.86,-328.35\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"387.81,-331.25 393.74,-322.47 383.63,-325.64 387.81,-331.25\"/>\n",
       "</g>\n",
       "<!-- train_test_split_func&#45;&gt;y_test -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>train_test_split_func&#45;&gt;y_test</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M300.83,-360.15C270.84,-334.5 213.71,-285.63 180.4,-257.14\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"182.71,-254.5 172.84,-250.66 178.16,-259.82 182.71,-254.5\"/>\n",
       "</g>\n",
       "<!-- train_test_split_func&#45;&gt;X_test -->\n",
       "<g id=\"edge9\" class=\"edge\">\n",
       "<title>train_test_split_func&#45;&gt;X_test</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M299.32,-360.25C289.13,-350.83 277.97,-338.13 272.52,-324 265.07,-304.67 266.78,-281.21 270.11,-263.23\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"273.52,-264.04 272.2,-253.53 266.67,-262.57 273.52,-264.04\"/>\n",
       "</g>\n",
       "<!-- X_train -->\n",
       "<g id=\"node16\" class=\"node\">\n",
       "<title>X_train</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"320.52\" cy=\"-306\" rx=\"39.07\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"320.52\" y=\"-300.95\" font-family=\"Times,serif\" font-size=\"14.00\">X_train</text>\n",
       "</g>\n",
       "<!-- train_test_split_func&#45;&gt;X_train -->\n",
       "<g id=\"edge23\" class=\"edge\">\n",
       "<title>train_test_split_func&#45;&gt;X_train</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M320.52,-359.7C320.52,-352.41 320.52,-343.73 320.52,-335.54\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"324.02,-335.62 320.52,-325.62 317.02,-335.62 324.02,-335.62\"/>\n",
       "</g>\n",
       "<!-- fit_clf&#45;&gt;predicted_output -->\n",
       "<g id=\"edge24\" class=\"edge\">\n",
       "<title>fit_clf&#45;&gt;predicted_output</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M364.41,-215.7C358.83,-207.73 352.09,-198.1 345.91,-189.26\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"348.97,-187.53 340.36,-181.34 343.23,-191.54 348.97,-187.53\"/>\n",
       "</g>\n",
       "<!-- target_names -->\n",
       "<g id=\"node12\" class=\"node\">\n",
       "<title>target_names</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"175.52\" cy=\"-162\" rx=\"60.56\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"175.52\" y=\"-156.95\" font-family=\"Times,serif\" font-size=\"14.00\">target_names</text>\n",
       "</g>\n",
       "<!-- target_names&#45;&gt;predicted_output_with_labels -->\n",
       "<g id=\"edge11\" class=\"edge\">\n",
       "<title>target_names&#45;&gt;predicted_output_with_labels</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M205.56,-145.98C224.89,-136.32 250.34,-123.59 271.89,-112.82\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"273.37,-115.99 280.75,-108.39 270.24,-109.73 273.37,-115.99\"/>\n",
       "</g>\n",
       "<!-- target_names&#45;&gt;y_test_with_labels -->\n",
       "<g id=\"edge21\" class=\"edge\">\n",
       "<title>target_names&#45;&gt;y_test_with_labels</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M157.04,-144.41C147.47,-135.82 135.59,-125.16 124.97,-115.63\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"127.32,-113.04 117.54,-108.97 122.65,-118.25 127.32,-113.04\"/>\n",
       "</g>\n",
       "<!-- y_test_with_labels&#45;&gt;confusion_matrix -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>y_test_with_labels&#45;&gt;confusion_matrix</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M138.82,-74.33C166.68,-64.43 203.87,-51.21 234.92,-40.18\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"235.95,-43.53 244.2,-36.88 233.6,-36.93 235.95,-43.53\"/>\n",
       "</g>\n",
       "<!-- y_test_with_labels&#45;&gt;classification_report -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>y_test_with_labels&#45;&gt;classification_report</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M107.41,-71.7C111.87,-63.9 117.23,-54.51 122.19,-45.83\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"125.08,-47.84 127,-37.42 119,-44.36 125.08,-47.84\"/>\n",
       "</g>\n",
       "<!-- prefit_clf -->\n",
       "<g id=\"node14\" class=\"node\">\n",
       "<title>prefit_clf</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"514.52\" cy=\"-306\" rx=\"45.21\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"514.52\" y=\"-300.95\" font-family=\"Times,serif\" font-size=\"14.00\">prefit_clf</text>\n",
       "</g>\n",
       "<!-- prefit_clf&#45;&gt;fit_clf -->\n",
       "<g id=\"edge16\" class=\"edge\">\n",
       "<title>prefit_clf&#45;&gt;fit_clf</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M487.26,-291.17C466.28,-280.53 437.05,-265.7 413.89,-253.96\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"415.58,-250.89 405.08,-249.49 412.42,-257.13 415.58,-250.89\"/>\n",
       "</g>\n",
       "<!-- penalty -->\n",
       "<g id=\"node15\" class=\"node\">\n",
       "<title>penalty</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"514.52\" cy=\"-378\" rx=\"62.61\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"514.52\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">Input: penalty</text>\n",
       "</g>\n",
       "<!-- penalty&#45;&gt;prefit_clf -->\n",
       "<g id=\"edge22\" class=\"edge\">\n",
       "<title>penalty&#45;&gt;prefit_clf</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M514.52,-359.7C514.52,-352.41 514.52,-343.73 514.52,-335.54\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"518.02,-335.62 514.52,-325.62 511.02,-335.62 518.02,-335.62\"/>\n",
       "</g>\n",
       "<!-- X_train&#45;&gt;fit_clf -->\n",
       "<g id=\"edge17\" class=\"edge\">\n",
       "<title>X_train&#45;&gt;fit_clf</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M333.51,-288.76C340.1,-280.53 348.27,-270.32 355.71,-261.02\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"358.37,-263.3 361.88,-253.31 352.9,-258.93 358.37,-263.3\"/>\n",
       "</g>\n",
       "<!-- shuffle_train_test_split -->\n",
       "<g id=\"node17\" class=\"node\">\n",
       "<title>shuffle_train_test_split</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"150.52\" cy=\"-450\" rx=\"120.45\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"150.52\" y=\"-444.95\" font-family=\"Times,serif\" font-size=\"14.00\">Input: shuffle_train_test_split</text>\n",
       "</g>\n",
       "<!-- shuffle_train_test_split&#45;&gt;train_test_split_func -->\n",
       "<g id=\"edge15\" class=\"edge\">\n",
       "<title>shuffle_train_test_split&#45;&gt;train_test_split_func</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M190.38,-432.59C214.88,-422.5 246.42,-409.51 272.32,-398.85\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"273.35,-402.21 281.27,-395.16 270.69,-395.74 273.35,-402.21\"/>\n",
       "</g>\n",
       "<!-- digit_data -->\n",
       "<g id=\"node18\" class=\"node\">\n",
       "<title>digit_data</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"320.52\" cy=\"-522\" rx=\"47.77\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"320.52\" y=\"-516.95\" font-family=\"Times,serif\" font-size=\"14.00\">digit_data</text>\n",
       "</g>\n",
       "<!-- digit_data&#45;&gt;feature_matrix -->\n",
       "<g id=\"edge7\" class=\"edge\">\n",
       "<title>digit_data&#45;&gt;feature_matrix</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M345.12,-506.15C361.29,-496.4 382.69,-483.48 400.73,-472.6\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"402.3,-475.74 409.05,-467.58 398.68,-469.75 402.3,-475.74\"/>\n",
       "</g>\n",
       "<!-- digit_data&#45;&gt;target -->\n",
       "<g id=\"edge8\" class=\"edge\">\n",
       "<title>digit_data&#45;&gt;target</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M320.52,-503.7C320.52,-496.41 320.52,-487.73 320.52,-479.54\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"324.02,-479.62 320.52,-469.62 317.02,-479.62 324.02,-479.62\"/>\n",
       "</g>\n",
       "<!-- digit_data&#45;&gt;target_names -->\n",
       "<g id=\"edge19\" class=\"edge\">\n",
       "<title>digit_data&#45;&gt;target_names</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M272.99,-518.88C197.89,-514.54 56.7,-501.97 21.52,-468 -7.57,-439.9 2.52,-419.45 2.52,-379 2.52,-379 2.52,-379 2.52,-305 2.52,-241.15 73.79,-200.77 124.85,-180.03\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"125.95,-183.35 133.99,-176.45 123.4,-176.83 125.95,-183.35\"/>\n",
       "</g>\n",
       "<!-- predicted_output&#45;&gt;predicted_output_with_labels -->\n",
       "<g id=\"edge10\" class=\"edge\">\n",
       "<title>predicted_output&#45;&gt;predicted_output_with_labels</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M324.56,-143.7C323.29,-136.32 321.79,-127.52 320.37,-119.25\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"323.85,-118.86 318.71,-109.6 316.95,-120.04 323.85,-118.86\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x129594590>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dr.visualize_execution(['classification_report', 'confusion_matrix', 'fit_clf'],\n",
    "                       f'./model_dag_{_data_set}_{_model_type}.dot', {\"format\": \"png\"})"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hamilton",
   "language": "python",
   "name": "hamilton"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
